# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import division

import hashlib
import operator

import numpy as np
import pandas as pd
import pyarrow as pa

from petastorm.cache import NullCache
from petastorm.workers_pool.worker_base import WorkerBase
from petastorm.arrow_reader_worker import ArrowReaderWorkerResultsQueueReader


class ArrowCarbonReaderWorker(WorkerBase):
  def __init__(self, worker_id, publish_func, args):
    super(ArrowCarbonReaderWorker, self).__init__(worker_id, publish_func, args)

    self._filesystem = args[0]
    self._dataset_path = args[1]
    self._schema = args[2]
    self._ngram = args[3]
    self._split_pieces = args[4]
    self._local_cache = args[5]
    self._transform_spec = args[6]

    if self._ngram:
      raise NotImplementedError('ngrams are not supported by ArrowReaderWorker')

    # We create datasets lazily in the first invocation of 'def process'. This speeds up startup time since
    # all Worker constructors are serialized
    self._dataset = None

  @staticmethod
  def new_results_queue_reader():
    return ArrowReaderWorkerResultsQueueReader()

  # pylint: disable=arguments-differ
  def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
    """Main worker function. Loads and returns all rows matching the predicate from a blocklet

    Looks up the requested piece (a single row-group in a carbon file). If a predicate is specified,
    columns needed by the predicate are loaded first. If no rows in the blocklet matches the predicate criteria
    the rest of the columns are not loaded.

    :param piece_index:
    :param shuffle_row_drop_partition: A tuple 2 of the current row drop partition and the total number
        of partitions.
    :return:
    """

    piece = self._split_pieces[piece_index]

    if not isinstance(self._local_cache, NullCache):
      if worker_predicate:
        raise RuntimeError('Local cache is not supported together with predicates, '
                           'unless the dataset is partitioned by the column the predicate operates on.')
      if shuffle_row_drop_partition[1] != 1:
        raise RuntimeError('Local cache is not supported together with shuffle_row_drop_partitions > 1')

    if worker_predicate:
      all_cols = self._load_rows_with_predicate(piece, worker_predicate, shuffle_row_drop_partition)
    else:
      # Using hash of the dataset path with the relative path in order to:
      #  1. Make sure if a common cache serves multiple processes (e.g. redis), we don't have conflicts
      #  2. Dataset path is hashed, to make sure we don't create too long keys, which maybe incompatible with
      #     some cache implementations
      #  3. Still leave relative path and the piece_index in plain text to make it easier to debug
      cache_key = '{}:{}:{}'.format(hashlib.md5(self._dataset_path.encode('utf-8')).hexdigest(),
                                    piece.path, piece_index)
      all_cols = self._local_cache.get(cache_key,
                                       lambda: self._load_rows(piece, shuffle_row_drop_partition))

    if all_cols:
      self.publish_func(all_cols)

  def _load_rows(self, piece, shuffle_row_drop_range):
    """Loads all rows from a piece"""

    # pyarrow would fail if we request a column names that the dataset is partitioned by, so we strip them from
    # the `columns` argument.
    # partitions = self._dataset.partitions
    column_names_in_schema = list(field.name for field in self._schema.fields.values())
    # column_names = column_names_in_schema - partitions.partition_names

    result = self._read_with_shuffle_row_drop(piece, column_names_in_schema, shuffle_row_drop_range)

    if self._transform_spec:
      result = pa.Table.from_pandas(self._transform_spec.func(result.to_pandas()), preserve_index=False)

    return result

  def _load_rows_with_predicate(self, piece, worker_predicate, shuffle_row_drop_partition):
    """Loads all rows that match a predicate from a piece"""

    # 1. Read all columns needed by predicate
    # 2. Apply the predicate. If nothing matches, exit early
    # 3. Read the remaining columns

    # Split all column names into ones that are needed by predicateand the rest.
    predicate_column_names = set(worker_predicate.get_fields())

    if not predicate_column_names:
      raise ValueError('At least one field name must be returned by predicate\'s get_field() method')

    all_schema_names = set(field.name for field in self._schema.fields.values())

    invalid_column_names = predicate_column_names - all_schema_names
    if invalid_column_names:
      raise ValueError('At least some column names requested by the predicate ({}) '
                       'are not valid schema names: ({})'.format(', '.join(invalid_column_names),
                                                                 ', '.join(all_schema_names)))

    # Split into 'columns for predicate evaluation' and 'other columns'. We load 'other columns' only if at
    # least one row in the blocklet matched the predicate
    other_column_names = all_schema_names - predicate_column_names

    # Read columns needed for the predicate
    predicate_column_names_list = list(predicate_column_names)
    predicates_table = self._read_with_shuffle_row_drop(piece, predicate_column_names_list,
                                                        shuffle_row_drop_partition)

    predicates_data_frame = predicates_table.to_pandas()

    match_predicate_mask = worker_predicate.do_include(predicates_data_frame)
    erase_mask = match_predicate_mask.map(operator.not_)

    # Don't have anything left after filtering? Exit early.
    if erase_mask.all():
      return []

    predicates_data_frame[erase_mask] = None

    if other_column_names:
      # Read remaining columns
      other_column_names_list = list(other_column_names)
      other_table = self._read_with_shuffle_row_drop(piece, other_column_names_list,
                                                     shuffle_row_drop_partition)
      other_data_frame = other_table.to_pandas()
      other_data_frame[erase_mask] = None

      # Partition-by columns will appear in both other and predicate data frames. Deduplicate.
      columns_from_predicates = predicates_data_frame.columns.difference(other_data_frame.columns)
      result_data_frame = pd.merge(predicates_data_frame[columns_from_predicates], other_data_frame,
                                   copy=False, left_index=True, right_index=True)
    else:
      result_data_frame = predicates_data_frame

    result = result_data_frame[match_predicate_mask]

    if self._transform_spec:
      result = self._transform_spec.func(result)

    return pa.Table.from_pandas(result, preserve_index=False)

  def _read_with_shuffle_row_drop(self, piece, column_names, shuffle_row_drop_partition):
    table = piece.read_all(
      columns=column_names,
    )

    num_rows = len(table)
    num_partitions = shuffle_row_drop_partition[1]
    this_partition = shuffle_row_drop_partition[0]

    if num_partitions > 1:
      data_frame_pandas = table.to_pandas()
      partition_indexes = np.floor(np.arange(num_rows) / (float(num_rows) / min(num_rows, num_partitions)))

      table = pa.Table.from_pandas(data_frame_pandas.loc[partition_indexes == this_partition],
                                   preserve_index=False)

    return table
